[low-bit optim] Add coat for float8 optimizer #1231
base: main
Conversation
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/ao/1231
Note: Links to docs will display an error until the docs builds have been completed. This comment was automatically generated by Dr. CI and updates every 15 minutes.
I was thinking you can just add a flag to the current OptimStateFp8.
I have added the flag for OptimStateFp8. Could you verify it's right?
I think this requires a bit more work. You need to verify that you can create an optimizer with this (add a test to https://github.com/pytorch/ao/blob/main/test/prototype/test_low_bit_optim.py) as well as do some short training runs for sanity checks (using https://github.com/pytorch/ao/blob/main/benchmarks/benchmark_low_bit_adam.py). A rough sketch of such a smoke test is shown below. I think for merging the PR, we should wait for the official code release so we can check numerics against them. If you don't mind, we can discuss more details in the GPU-MODE Discord group https://discord.gg/gpumode. Just create a thread under torchao and tag me (@gau.nernst).
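Something along these lines would be enough as a smoke test (a sketch only; `dynamic_range_expansion` is a placeholder for whatever flag the PR ends up adding to `AdamWFp8`):

```python
# Minimal smoke-test sketch for the COAT-enabled fp8 optimizer.
# NOTE: `dynamic_range_expansion` is an assumed flag name, not the final API.
import torch
from torchao.prototype.low_bit_optim import AdamWFp8

def test_adamw_fp8_coat_smoke():
    model = torch.nn.Linear(128, 128, device="cuda", dtype=torch.bfloat16)
    optim = AdamWFp8(model.parameters(), lr=1e-3, dynamic_range_expansion=True)
    for _ in range(3):
        loss = model(torch.randn(4, 128, device="cuda", dtype=torch.bfloat16)).sum()
        loss.backward()
        optim.step()
        optim.zero_grad()
```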
I understand the situation for merging the PR. I will be glad to keep working on this issue. Creating a thread in GPU-MODE.
Force-pushed from 4c45349 to 7be5a6b
…skip marker to within the function.
* Show a8wxdq load error only when the quant is used
* Update Error check
This reverts commit 0bbba59.
        self.block_size = codes.numel() // scale.numel()
        self.sqrt_minmax_exp = sqrt_minmax_exp

    def __tensor_flatten__(self):
        return self.tensor_attrs, []
When `k` and `sqrt_minmax_exp` are not None, you need to return them here (in `__tensor_flatten__()`) as well.
Should I pass them instead of the empty array?
The first returned value (currently `self.tensor_attrs`) is a list of strings containing the names of the tensor attributes. In this case, when there is no dynamic range extension, it's just `["codes", "scale"]`. However, when there is dynamic range extension, you need to also add `"k"` and `"sqrt_minmax_exp"`. IIRC, when they are None, you are not supposed to include them.
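Roughly something like this (just a sketch, assuming `k` is None whenever the extension is disabled):

```python
# Sketch: include the optional attributes only when dynamic range expansion
# is enabled, i.e. when `k` / `sqrt_minmax_exp` tensors are set.
def __tensor_flatten__(self):
    attrs = ["codes", "scale"]
    if self.k is not None:
        attrs += ["k", "sqrt_minmax_exp"]
    return attrs, []
```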
Thank you for explaining it. I have added them. Looking forward to your approval.
dev-requirements.txt (Outdated)
@@ -21,8 +21,7 @@ lm_eval
diskcache
pycocotools
tqdm

# Custom CUDA Extensions
git+https://github.com/NVlabs/COAT.git#subdirectory=coat/optimizer/kernels # Custom CUDA Extensions
Don't add this. The CPU runner will fail to build the CUDA extension. We will just test this locally.
This is a Work in Progress PR for #1190.
As a draft PR, I have followed the first piece of advice from @gau-nernst of "extending OptimStateFp8". Instead of creating a different quantize_fp8 method, I have created a separate dynamic range expansion function, since it is applied before quantization to achieve a larger representable range for the float8 dtypes; the class also stores the value k so the expansion can be inverted after dequantization.
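A simplified per-tensor sketch of the idea (the helper names and the exact expansion form here reflect my reading of COAT, not the final code in this PR):

```python
# Sketch of dynamic range expansion around fp8 quantization.
# f(x) = sign(x) * |x|^k is applied before quantization so values better fill
# the float8 representable range; k is kept to invert the transform after
# dequantization. Names (expand/contract/...) are illustrative only.
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # ~448 for e4m3

def expand(x: torch.Tensor, k: float) -> torch.Tensor:
    # Stretch magnitudes; sign is handled separately so negatives survive pow().
    return torch.sign(x) * x.abs().pow(k)

def contract(x: torch.Tensor, k: float) -> torch.Tensor:
    # Inverse of expand(), applied after dequantization.
    return torch.sign(x) * x.abs().pow(1.0 / k)

def quantize_fp8_expanded(x: torch.Tensor, k: float):
    x_exp = expand(x, k)
    scale = x_exp.abs().amax().clamp(min=1e-12) / FP8_MAX  # per-tensor scale for brevity
    codes = (x_exp / scale).to(FP8_DTYPE)
    return codes, scale, k  # k must be stored to undo the expansion later

def dequantize_fp8_expanded(codes: torch.Tensor, scale: torch.Tensor, k: float):
    return contract(codes.to(torch.float32) * scale, k)
```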
Requirements:
TBA
Additional Code/logic Added:
TBA
Logic/Code changes to existing codebase:
TBA
Outcome:
TBA
Scope of Usage:
TBA
Example
TBA
Changes:
Benchmarks
Parameters: lr, amp, optim
Results